# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

from __future__ import print_function
import argparse, json, os, itertools, random, shutil
import time
import re
import pdb
import copy
from pprint import pprint
import numpy as np
from sqlalchemy import null 

import question_engine as qeng

"""
Generate synthetic questions and answers for CLEVR images. Input is a single
JSON file containing ground-truth scene information for all images, and output
is a single JSON file containing all generated questions, answers, and programs.

Questions are generated by expanding templates. Each template contains a single
program template and one or more text templates, both with the same set of typed
slots; by convention <Z> = Size, <C> = Color, <M> = Material, <S> = Shape.

Program templates may contain special nodes that expand into multiple functions
during instantiation; for example a "filter" node in a program template will
expand into a combination of "filter_size", "filter_color", "filter_material",
and "filter_shape" nodes after instantiation, and a "filter_unique" node in a
template will expand into some combination of filtering nodes followed by a
"unique" node.

Templates are instantiated using depth-first search; we are looking for template
instantiations where (1) each "unique" node actually refers to a single object,
(2) constraints in the template are satisfied, and (3) the answer to the question
passes our rejection sampling heuristics.

To efficiently handle (1) and (2), we keep track of partial evaluations of the
program during each step of template expansion. This together with the use of
composite nodes in program templates (filter_unique, relate_filter_unique) allows
us to efficiently prune the search space and terminate early when we know that
(1) or (2) will be violated.
"""



# Command-line interface. NOTE: parsing is intentionally deferred (see the
# commented-out parse_args call below) so this module can be imported without
# consuming sys.argv.
parser = argparse.ArgumentParser()

# Inputs
parser.add_argument('--input_scene_file', default='../output/CLEVR_scenes.json',
        help="JSON file containing ground-truth scene information for all images " +
                 "from render_images.py")
parser.add_argument('--metadata_file', default='metadata.json',
        help="JSON file containing metadata about functions")
parser.add_argument('--synonyms_json', default='synonyms.json',
        help="JSON file defining synonyms for parameter values")
parser.add_argument('--template_dir', default='CLEVR_1.0_templates',
        help="Directory containing JSON templates for questions")

# Output
parser.add_argument('--output_questions_file',
        default='../output/CLEVR_questions.json',
        help="The output file to write containing generated questions")

# Control which and how many images to process
parser.add_argument('--scene_start_idx', default=0, type=int,
        help="The image at which to start generating questions; this allows " +
                 "question generation to be split across many workers")
parser.add_argument('--num_scenes', default=0, type=int,
        help="The number of images for which to generate questions. Setting to 0 " +
                 "generates questions for all scenes in the input file starting from " +
                 "--scene_start_idx")

# Control the number of questions per image; we will attempt to generate
# templates_per_image * instances_per_template questions per image.
parser.add_argument('--templates_per_image', default=10, type=int,
        help="The number of different templates that should be instantiated " +
                 "on each image")
parser.add_argument('--instances_per_template', default=1, type=int,
        help="The number of times each template should be instantiated on an image")

# Misc
parser.add_argument('--remove_redundant', type=float, default=0.0,
        help="-1.0 ~ +1.0. Filter out (>0) or add (<0) redundant filters in the question generation procedure. Will filter out with probability, default is 0, no filtering.")
parser.add_argument('--reset_counts_every', default=250, type=int,
        help="How often to reset template and answer counts. Higher values will " +
                 "result in flatter distributions over templates and answers, but " +
                 "will result in longer runtimes.")
parser.add_argument('--verbose', action='store_true',
        help="Print more verbose output")
parser.add_argument('--time_dfs', action='store_true',
        help="Time each depth-first search; must be given with --verbose")
parser.add_argument('--profile', action='store_true',
        help="If given then run inside cProfile")
# args = parser.parse_args()


def precompute_filter_options(scene_struct, metadata, remove_redundant=0.0):
    """Precompute the set of objects matched by every possible attribute filter.

    Adds a '_filter_options' dict to ``scene_struct``: keys are tuples
    (size, color, material, shape), where any entry may be None meaning
    "don't filter on this attribute", and values are sets of object indices
    matching that (partial) filter.

    remove_redundant: probability in [0, 1] of dropping a filter whose
    denotation equals that of a strictly less specific filter (see
    drop_redundant_filters). The default 0.0 keeps everything; the old
    default of False behaved identically since False > 0.0 is False.
    """
    attribute_map = {}

    if metadata['dataset'] == 'CLEVR-v1.0':
        attr_keys = ['size', 'color', 'material', 'shape']
    else:
        assert False, 'Unrecognized dataset'

    # Precompute all 2**4 = 16 binary masks over the attribute positions;
    # mask[j] == 1 means "keep attribute j", 0 means wildcard (None).
    masks = []
    for i in range(2 ** len(attr_keys)):
        masks.append([(i // (2 ** j)) % 2 for j in range(len(attr_keys))])

    for object_idx, obj in enumerate(scene_struct['objects']):
        # TODO: modify the key here to support obj name hierarchy
        # (sedan -> car) by mapping the key to its super types
        # (predefined in a dict: sedan -> car, suv -> car)
        key = tuple(obj[k] for k in attr_keys)
        for mask in masks:
            masked_key = tuple(a if b == 1 else None for a, b in zip(key, mask))
            attribute_map.setdefault(masked_key, set()).add(object_idx)

            # Shape hierarchy: also register the object under its super-shape
            # (e.g. sedan -> car) so filters can refer to either level.
            if masked_key[-1] is not None:
                hypershape = metadata['_shape_hier'][masked_key[-1]]
                hyper_key = masked_key[:-1] + (hypershape,)
                attribute_map.setdefault(hyper_key, set()).add(object_idx)

    if remove_redundant > 0.0:
        attribute_map = drop_redundant_filters(attribute_map, remove_redundant)

    scene_struct['_filter_options'] = attribute_map


def complete_parts(scene_struct, metadata):
    """Attach a '_parts' dict to every object in the scene.

    For each object, '_parts' maps part_idx (the index of the part name in
    metadata['types']['Partname'][shape]) to a dict holding 'partname' plus
    every attribute recorded for that part in obj['parts']. Part names
    listed in the metadata but absent from the object are skipped, so only
    the object's annotated ("unique") parts are kept.

    Note: an unused lookup of metadata['_shape_hier'] (a leftover from a
    commented-out variant that indexed parts by super-shape) was removed.
    """
    for obj in scene_struct['objects']:
        obj['_parts'] = {}
        part_names = metadata['types']['Partname'][obj['shape']]
        for part_idx, part_name in enumerate(part_names):
            # Only keep parts actually annotated on this object.
            if part_name in obj['parts']:
                part = {'partname': part_name}
                part.update(obj['parts'][part_name])
                obj['_parts'][part_idx] = part

def precompute_partfilter_options(scene_struct, metadata, obj_idx, remove_redundant=0.0):
    """Precompute part-level filter options for one object.

    Adds a '_partfilter_options' dict to scene_struct['objects'][obj_idx]:
    keys are tuples (size, color, material, partname), where any entry may
    be None meaning "don't care", and values are sets of part indices of
    this object matching the (partial) filter. The fully-unconstrained key
    (None, None, None, None) is removed because an empty filter is not
    allowed for parts.

    remove_redundant: probability of dropping filters whose denotation
    equals that of a strictly less specific filter (see
    drop_redundant_filters).
    """
    attribute_map = {}

    if metadata['dataset'] == 'CLEVR-v1.0':
        attr_keys = ['size', 'color', 'material', 'partname']
    else:
        assert False, 'Unrecognized dataset'

    # All 2**4 = 16 binary masks; mask[j] == 1 keeps attribute j, 0
    # wildcards it (None).
    masks = []
    for i in range(2 ** len(attr_keys)):
        masks.append([(i // (2 ** j)) % 2 for j in range(len(attr_keys))])

    assert '_parts' in scene_struct['objects'][obj_idx]
    for part_idx, part in scene_struct['objects'][obj_idx]['_parts'].items():
        key = tuple(part[k] for k in attr_keys)
        for mask in masks:
            masked_key = tuple(a if b == 1 else None for a, b in zip(key, mask))
            attribute_map.setdefault(masked_key, set()).add(part_idx)

            # Part-name hierarchy: strip positional qualifiers so e.g.
            # 'door_left' is also registered under the super-part 'door'.
            if masked_key[-1] is not None:
                special_words = ['right', 'left', 'front', 'back', 'center', 'mid', 's']
                super_partname = '_'.join(
                    a for a in masked_key[-1].split('_') if a not in special_words)
                if super_partname != masked_key[-1]:
                    super_key = masked_key[:-1] + (super_partname,)
                    attribute_map.setdefault(super_key, set()).add(part_idx)

    # Drop the empty filter. BUGFIX: use a default so an object with no
    # parts (empty '_parts') no longer raises KeyError here.
    attribute_map.pop((None, None, None, None), None)
    if remove_redundant:
        attribute_map = drop_redundant_filters(attribute_map, remove_redundant)

    scene_struct['objects'][obj_idx]['_partfilter_options'] = attribute_map
    
def subsumes(k1, k2):
    """Return True when filter key k1 is strictly more restrictive than k2.

    k1 subsumes k2 iff (1) k1 agrees with k2 on every position where k2 is
    not None, and (2) k1 is non-None in at least one position where k2 is
    None.
    """
    assert len(k1) == len(k2)
    strictly_tighter = False
    for a, b in zip(k1, k2):
        if b is not None:
            # Condition 1: must agree wherever k2 constrains.
            if a != b:
                return False
        elif a is not None:
            # Condition 2: k1 constrains a position k2 leaves open.
            strictly_tighter = True
    return strictly_tighter

def drop_redundant_filters(attribute_map, p_remove):
    """Randomly remove filters that are redundant refinements.

    A key k1 is a removal candidate when it subsumes some other key k2
    (k1 is strictly more specific) yet both denote exactly the same object
    set -- the extra attributes add no information. Each candidate is
    dropped independently with probability ``p_remove``.
    """
    to_drop = []
    for k1, denot1 in attribute_map.items():
        for k2, denot2 in attribute_map.items():
            if k1 == k2 or not subsumes(k1, k2):
                continue
            # Sanity check: a more specific filter never matches more objects.
            assert len(denot1) <= len(denot2)
            if len(denot1) == len(denot2):
                if np.random.choice([True, False], p=[p_remove, 1 - p_remove]):
                    to_drop.append(k1)

    # use None as sentinel for later
    # new_attribute_map = {k:v if k not in to_drop else None for k, v in attribute_map.items()}
    return {key: val for key, val in attribute_map.items() if key not in to_drop}
    
def find_partfilter_options(objectpart_idxs, scene_struct, metadata, remove_redundant=0.0):
    """Group part-filter options for a set of object/part indices.

    objectpart_idxs: iterable of "objIdx_partIdx" strings.
    Returns a dict whose keys are filter tuples (size, color, material,
    partname) -- entries may be None -- and whose values are dicts
    {obj_idx: set(part_idxs)} restricted to the requested parts.
    Missing per-object '_partfilter_options' caches are computed on demand.
    """
    # Parse "obj_part" strings into {obj_idx: {part_idx, ...}}.
    parts_by_obj = {}
    for token in objectpart_idxs:
        obj_idx, part_idx = (int(piece) for piece in token.split('_'))
        parts_by_obj.setdefault(obj_idx, set()).add(part_idx)

    attribute_map = {}
    for obj_idx, wanted in parts_by_obj.items():
        if '_partfilter_options' not in scene_struct['objects'][obj_idx]:
            precompute_partfilter_options(
                scene_struct, metadata, obj_idx, remove_redundant=remove_redundant)

        for key, candidates in scene_struct['objects'][obj_idx]['_partfilter_options'].items():
            matched = wanted & candidates
            per_obj = attribute_map.setdefault(key, {})
            per_obj.setdefault(obj_idx, set()).update(matched)

    return attribute_map

def find_filter_options(object_idxs, scene_struct, metadata, remove_redundant=0.0):
    """Restrict the scene-wide filter options to the given object indices.

    Returns a dict mapping filter keys (size, color, material, shape) --
    entries may be None -- to lists of indices drawn from object_idxs that
    match the filter. The scene-wide '_filter_options' cache is computed on
    first use.
    """
    if '_filter_options' not in scene_struct:
        precompute_filter_options(
            scene_struct, metadata, remove_redundant=remove_redundant)

    candidates = set(object_idxs)
    return {key: list(candidates & members)
            for key, members in scene_struct['_filter_options'].items()}


def add_empty_filter_options(attribute_map, metadata, num_to_add):
    """Add ``num_to_add`` random filter keys that match no objects.

    Each inserted key maps to an empty list. Candidate keys are drawn
    uniformly, one attribute at a time, from the metadata value sets (each
    extended with None) until enough distinct unused keys exist. Mutates
    attribute_map in place.
    """
    if metadata['dataset'] == 'CLEVR-v1.0':
        attr_keys = ['Size', 'Color', 'Material', 'Shape']
        # attr_keys = ['Size', 'Color', 'Shape']
    else:
        assert False, 'Unrecognized dataset'

    attr_vals = [metadata['types'][t] + [None] for t in attr_keys]
    # Optional override: a precomputed pool of per-position values.
    if '_filter_options' in metadata:
        attr_vals = metadata['_filter_options']

    target_size = len(attribute_map) + num_to_add
    while len(attribute_map) < target_size:
        candidate = tuple(random.choice(vals) for vals in attr_vals)
        if candidate not in attribute_map:
            attribute_map[candidate] = []


def find_relate_filter_options(object_idx, scene_struct, metadata,
        unique=False, include_zero=False, trivial_frac=0.1, remove_redundant=0.0):
    """Enumerate (relationship, filter) pairs applicable to object_idx.

    For every spatial relationship and every precomputed filter, intersect
    the objects related to object_idx with the filter's denotation.
    Non-trivial intersections (strict subsets of the filter output) are
    always kept; trivial ones (intersection equals the filter output) are
    randomly subsampled to make up roughly ``trivial_frac`` of the result.

    unique: keep only intersections containing exactly one object.
    include_zero: also keep empty intersections.
    """
    if '_filter_options' not in scene_struct:
        precompute_filter_options(scene_struct, metadata,
                                  remove_redundant=remove_redundant)

    # TODO: Right now this is only looking for nontrivial combinations; in some
    # cases I may want to add trivial combinations, either where the intersection
    # is empty or where the intersection is equal to the filtering output.
    options = {}
    trivial_options = {}
    for relationship, related_lists in scene_struct['relationships'].items():
        related = set(related_lists[object_idx])
        for filters, filtered in scene_struct['_filter_options'].items():
            intersection = related & filtered
            if unique and len(intersection) != 1:
                continue
            if not include_zero and len(intersection) == 0:
                continue
            bucket = trivial_options if intersection == filtered else options
            bucket[(relationship, filters)] = sorted(intersection)

    # Admit just enough randomly-chosen trivial pairs to reach trivial_frac.
    num_trivial = int(round(len(options) * trivial_frac / (1 - trivial_frac)))
    shuffled = list(trivial_options.items())
    # trivial_options = sorted(trivial_options, key = lambda x: str(x))
    random.shuffle(shuffled)
    options.update(shuffled[:num_trivial])

    return options


def node_shallow_copy(node):
    """Copy a program node's 'type', 'inputs', and optional 'side_inputs'.

    The contained lists are shared with the original (shallow copy); any
    cached fields such as '_output' are deliberately not carried over.
    """
    copied = {'type': node['type'], 'inputs': node['inputs']}
    if 'side_inputs' in node:
        copied['side_inputs'] = node['side_inputs']
    return copied


def other_heuristic(text, param_vals):
    """Post-processing heuristic to handle the word "other".

    When the question compares two attribute sets (<Z>/<C>/<M>/<S> vs their
    <*2> counterparts) that explicitly differ on some attribute, the words
    "other"/"another" are misleading and get stripped from the text.
    Otherwise the text is returned unchanged.
    """
    if ' other ' not in text and ' another ' not in text:
        return text
    expected_keys = {
        '<Z>',  '<C>',  '<M>',  '<S>',
        '<Z2>', '<C2>', '<M2>', '<S2>',
    }
    # Only applies to templates with exactly the two full attribute sets.
    if param_vals.keys() != expected_keys:
        return text
    pairs = [('<Z>', '<Z2>'), ('<C>', '<C2>'), ('<M>', '<M2>'), ('<S>', '<S2>')]
    for k1, k2 in pairs:
        v1 = param_vals.get(k1, None)
        v2 = param_vals.get(k2, None)
        if v1 != '' and v2 != '' and v1 != v2:
            print('other has got to go! %s = %s but %s = %s'
                        % (k1, v1, k2, v2))
            # An explicit attribute difference makes "other" misleading.
            return text.replace(' other ', ' ').replace(' another ', ' a ')
    return text


def get_question_hash(image_idx, scene_struct, _question, text):
    """Build a hash identifying what a question asks about.

    The hash combines the image index, the final query type, and -- in
    traversal order -- the shapes of the objects (and names of the parts)
    selected by every 'unique' node in the program. Questions that refer to
    the same objects/parts with redundant descriptions therefore collide.
    Returns "ERROR" unchanged for error-marked questions.
    """
    # Deep-copy so popping from 'inputs' below cannot mutate the caller's
    # program nodes.
    question = copy.deepcopy(_question)
    if question == "ERROR":
        return "ERROR"

    final_node = question[-1]
    query_type = final_node['type']
    pending = final_node['inputs']

    ops = []
    while len(pending) > 0:
        node = question[pending.pop()]
        if node['type'] != 'unique':
            continue
        output = node['_output']
        # '_output' is either an object index (int) or an "obj_part" string.
        if isinstance(output, str) and "_" in output:
            object_idx, part_idx = (int(x) for x in output.split("_"))
        else:
            object_idx, part_idx = output, None

        obj = scene_struct['objects'][object_idx]
        ops.append(f"{object_idx}_{obj['shape']}")
        if part_idx is not None:
            ops.append(obj['_parts'][part_idx]['partname'])

        # Keep walking up the program from this unique node only.
        pending += node['inputs']

    return f"{image_idx}_" + "_".join([query_type] + ops)

def get_equivalent_filter(k, keys, options):
    """Pick a random filter equivalent to k, or None when none exists.

    k has the format (size, color, shape, material); a candidate is
    equivalent when its last two components match k's and its entry in
    ``options`` is not None (i.e. it was not dropped as redundant).
    """
    candidates = [cand for cand in keys
                  if cand[-1] == k[-1] and cand[-2] == k[-2]
                  and options[cand] is not None]
    if not candidates:
        return None
    return random.choice(candidates)


def instantiate_templates_dfs(scene_struct, 
                              template, 
                              metadata, 
                              answer_counts,
                              synonyms, 
                              max_instances=None, 
                              remove_redundant=0.0,
                              verbose=False):

    param_name_to_type = {p['name']: p['type'] for p in template['params']} 
    
    null_params = []
    for constraint in template['constraints']:
        if constraint['type'] == 'NULL':
            p = constraint['params'][0]
            null_params.append(p)

    initial_state = {
        'nodes': [node_shallow_copy(template['nodes'][0])],
        'vals': {},
        'input_map': {0: 0},
        'next_template_node': 1,
    }
    states = [initial_state]
    final_states = []
    reject_count = 0
    while states:
        state = states.pop()

        # Check to make sure the current state is valid
        q = {'nodes': state['nodes']}
        outputs = qeng.answer_question(q, metadata, scene_struct, all_outputs=True)
        answer = outputs[-1] #len(outputs) is equal to len(state['nodes'], outputs contain answers after each node)

        if answer == '__INVALID__': continue

        # Check to make sure constraints are satisfied for the current state
        skip_state = False
        for constraint in template['constraints']:
            if constraint['type'] == 'NEQ':
                p1, p2 = constraint['params']
                v1, v2 = state['vals'].get(p1), state['vals'].get(p2)
                if v1 is not None and v2 is not None and v1 != v2:
                    if verbose:
                        print('skipping due to NEQ constraint')
                        print(constraint)
                        print(state['vals'])
                    skip_state = True
                    break
            elif constraint['type'] == 'NULL':
                p = constraint['params'][0]
                p_type = param_name_to_type[p]
                v = state['vals'].get(p)
                if v is not None:
                    skip = False
                    if p_type == 'Shape' and v != 'thing': skip = True
                    if p_type == 'Partname' and v != 'part': skip = True
                    if p_type not in ['Shape', 'Partname'] and v != '': skip = True
                    if skip:
                        if verbose:
                            print('skipping due to NULL constraint')
                            print(constraint)
                            print(state['vals'])
                        skip_state = True
                        break
            elif constraint['type'] == 'OUT_NEQ':
                i, j = constraint['params']
                i = state['input_map'].get(i, None)
                j = state['input_map'].get(j, None)
                if i is not None and j is not None and outputs[i] == outputs[j]:
                    if verbose:
                        print('skipping due to OUT_NEQ constraint')
                        print(outputs[i])
                        print(outputs[j])
                    skip_state = True
                    break
            else:
                assert False, 'Unrecognized constraint type "%s"' % constraint['type']

        if skip_state:
            continue

        # We have already checked to make sure the answer is valid, so if we have
        # processed all the nodes in the template then the current state is a valid
        # question, so add it if it passes our rejection sampling tests.
        if state['next_template_node'] == len(template['nodes']):
            # Use our rejection sampling heuristics to decide whether we should
            # keep this template instantiation
            if reject_count % 10000 == 0 and reject_count>0:
                print(reject_count)
                if reject_count % 100000 == 0:
                    print(answer_counts, template['text'][0])
            cur_answer_count = answer_counts[answer]
            answer_counts_sorted = sorted(answer_counts.values())
            median_count = answer_counts_sorted[len(answer_counts_sorted) // 2]
            median_count = max(median_count, 5)
            # if 'material' not in template['text'][0]:
            if cur_answer_count > 1.1 * (answer_counts_sorted[-2]+1):
                if verbose: print('skipping due to second count', len(states), answer, answer_counts)
                reject_count += 1
                continue
            if cur_answer_count > 5.0 * (median_count+1):
                if verbose: print('skipping due to median')
                reject_count += 1
                continue

            # If the template contains a raw relate node then we need to check for
            # degeneracy at the end
            has_relate = any(n['type'] == 'relate' for n in template['nodes'])
            if has_relate:
                degen = qeng.is_degenerate(q, metadata, scene_struct, answer=answer,
                                                                     verbose=verbose)
                # if remove_redundant < 0, then keep degenerated questions with prob=-remove_redundant
                degen &= (np.random.random() >= -remove_redundant)
                if degen:
                    reject_count += 1
                    continue

            answer_counts[answer] += 1
            state['answer'] = answer
            final_states.append(state)
            if max_instances is not None and len(final_states) == max_instances:
                break
            continue

        # Otherwise fetch the next node from the template
        # Make a shallow copy so cached _outputs don't leak ... this is very nasty
        next_node = template['nodes'][state['next_template_node']]
        next_node = node_shallow_copy(next_node)

        special_nodes = {
                'filter_unique', 'filter_count', 'filter_exist', 'filter',
                'relate_filter', 'relate_filter_unique', 'relate_filter_count',
                'relate_filter_exist',
                'partfilter_unique', 'partfilter'
        }
        
        if next_node['type'] in special_nodes:
            part_flag = ''
            if next_node['type'].startswith('relate_filter'):
                unique = (next_node['type'] == 'relate_filter_unique')
                include_zero = (next_node['type'] == 'relate_filter_count'
                                                or next_node['type'] == 'relate_filter_exist')
                filter_options = find_relate_filter_options(answer, scene_struct, metadata,
                                                        unique=unique, include_zero=include_zero,
                                                        remove_redundant=remove_redundant)
            else:
                if next_node['type'].startswith('part'):
                    part_flag = 'part'
                    filter_options = find_partfilter_options(answer, scene_struct, metadata, 
                                                             remove_redundant=remove_redundant)       
                    unified_node_type = next_node['type'][4:]
                else:
                    filter_options = find_filter_options(answer, scene_struct, metadata, 
                                                         remove_redundant=remove_redundant)       
                    unified_node_type = next_node['type']
            
                if unified_node_type == 'filter':
                    # Remove null filter
                    filter_options.pop((None, None, None, None), None)
                if unified_node_type == 'filter_unique':
                    # Get rid of all filter options that don't result in a single object
                    if part_flag == '':
                        filter_options = {k: v for k, v in filter_options.items()
                                                       if len(v) == 1}
                    else:
                        filter_options = {k: v for k, v in filter_options.items()
                                                    if len(v) == 1 and len(list(v.values())[0]) == 1 }
                else:
                    # Add some filter options that do NOT correspond to the scene
                    if unified_node_type == 'filter_exist':
                        # For filter_exist we want an equal number that do and don't
                        num_to_add = len(filter_options)
                    elif unified_node_type == 'filter_count' or unified_node_type == 'filter':
                        # For filter_count add nulls equal to the number of singletons
                        num_to_add = sum(1 for k, v in filter_options.items() if len(v) == 1)
                    add_empty_filter_options(filter_options, metadata, num_to_add)

            filter_option_keys = list(filter_options.keys())
            # filter_option_keys = sorted(filter_option_keys, key=lambda x: [str(y) for y in x])
            random.shuffle(filter_option_keys)

            for k in filter_option_keys:
                #if filter_options[k] is None:
                #    k = get_equivalent_filter(k, filter_option_keys, filter_options)
                #    if k is None:
                #        # we cannot find any equivalent non-redundant filters, so error out 
                #        # for the whole example 
                #        states.append(None)
                #        continue 
                #        # pdb.set_trace()



                new_nodes = []
                cur_next_vals = {k: v for k, v in state['vals'].items()}
                next_input = state['input_map'][next_node['inputs'][0]]
                filter_side_inputs = next_node['side_inputs']
                if next_node['type'].startswith('relate'):
                    param_name = next_node['side_inputs'][0] # First one should be relate
                    filter_side_inputs = next_node['side_inputs'][1:]
                    param_type = param_name_to_type[param_name]
                    assert param_type == 'Relation'
                    param_val = k[0]
                    k = k[1]
                    new_nodes.append({
                        'type': 'relate',
                        'inputs': [next_input],
                        'side_inputs': [param_val],
                    })
                    cur_next_vals[param_name] = param_val
                    next_input = len(state['nodes']) + len(new_nodes) - 1
                for param_name, param_val in zip(filter_side_inputs, k):
                    # filter_side_inputs: ['<Z>', '<C>', '<M>', '<S>']
                    # k: ('large', 'brown', None, None)
                    param_type = param_name_to_type[param_name]
                    filter_type = part_flag + 'filter_%s' % param_type.lower()
                    if param_val is not None:
                        new_nodes.append({
                            'type': filter_type,
                            'inputs': [next_input],
                            'side_inputs': [param_val],
                        })
                        cur_next_vals[param_name] = param_val
                        next_input = len(state['nodes']) + len(new_nodes) - 1
                    elif param_val is None:                                
                        if metadata['dataset'] == 'CLEVR-v1.0' and param_type == 'Shape':
                            param_val = 'thing'
                        elif metadata['dataset'] == 'CLEVR-v1.0' and param_type == 'Partname':
                            param_val = 'part'
                        else:
                            param_val = ''
                        cur_next_vals[param_name] = param_val
                    
                # add redundant modules here
                to_add_redundant = [param_name for param_name, param_val in zip(filter_side_inputs, k) if param_val is None]
                to_add_redundant = [a for a in to_add_redundant if a not in null_params]
                # if remove_redundant < 0 (-1~0), then keep to_add_redundant with p=(-remove_redundant)
                to_add_redundant = list(filter(lambda a: np.random.random() <= -remove_redundant, to_add_redundant))
                _outputs = qeng.answer_question({'nodes':state['nodes']+new_nodes}, metadata, scene_struct, all_outputs=True)
                def check_common_attr(objs, param_type):
                    # Return the single value of attribute `param_type` shared by
                    # every entry of `objs`, or None if the entries disagree (or
                    # `objs` is empty).  An int entry is a whole-object index into
                    # scene_struct['objects']; a string entry "objid_partid"
                    # addresses a part of an object.
                    attrs = []
                    for obj in objs:
                        if type(obj)==int:
                            attr = scene_struct['objects'][obj][param_type]
                        else:
                            assert('_' in obj)
                            obj_id, part_id = [int(a) for a in obj.split('_')]
                            # Part names are listed per shape in the metadata;
                            # part_id indexes into that per-shape list.
                            part_name = metadata['types']['Partname'][scene_struct['objects'][obj_id]["shape"]][part_id]
                            attr = scene_struct['objects'][obj_id]['parts'][part_name][param_type]
                        attrs.append(attr)
                    if len(set(attrs)) == 1:
                        return attrs[0]
                    else:
                        return None
                for param_name in to_add_redundant:
                    param_type = param_name_to_type[param_name]
                    param_val = check_common_attr(_outputs[-1], param_type.lower())
                    if param_val is not None:
                        filter_type = part_flag + 'filter_%s' % param_type.lower()
                        if param_val is not None:
                            new_nodes.append({
                                'type': filter_type,
                                'inputs': [next_input],
                                'side_inputs': [param_val],
                            })
                            cur_next_vals[param_name] = param_val
                            next_input = len(state['nodes']) + len(new_nodes) - 1
                
                input_map = {k: v for k, v in state['input_map'].items()}
                extra_type = None
                if next_node['type'].endswith('unique'):
                    extra_type = 'unique'
                if next_node['type'].endswith('count'):
                    extra_type = 'count'
                if next_node['type'].endswith('exist'):
                    extra_type = 'exist'
                if extra_type is not None:
                    new_nodes.append({
                        'type': extra_type,
                        'inputs': [input_map[next_node['inputs'][0]] + len(new_nodes)],
                    })
                input_map[state['next_template_node']] = len(state['nodes']) + len(new_nodes) - 1
                
                
                states.append({
                    'nodes': state['nodes'] + new_nodes,
                    'vals': cur_next_vals,
                    'input_map': input_map,
                    'next_template_node': state['next_template_node'] + 1,
                })
        
        elif 'side_inputs' in next_node:
            # If the next node has template parameters, expand them out
            # TODO: Generalize this to work for nodes with more than one side input
            assert len(next_node['side_inputs']) == 1, 'NOT IMPLEMENTED'

            # Use metadata to figure out domain of valid values for this parameter.
            # Iterate over the values in a random order; then it is safe to bail
            # from the DFS as soon as we find the desired number of valid template
            # instantiations.
            param_name = next_node['side_inputs'][0]
            param_type = param_name_to_type[param_name]
            param_vals = metadata['types'][param_type][:]
            # param_vals = sorted(param_vals, key= lambda x: str(x))
            random.shuffle(param_vals)
            for val in param_vals:
                input_map = {k: v for k, v in state['input_map'].items()}
                input_map[state['next_template_node']] = len(state['nodes'])
                cur_next_node = {
                    'type': next_node['type'],
                    'inputs': [input_map[idx] for idx in next_node['inputs']],
                    'side_inputs': [val],
                }
                cur_next_vals = {k: v for k, v in state['vals'].items()}
                cur_next_vals[param_name] = val

                states.append({
                    'nodes': state['nodes'] + [cur_next_node],
                    'vals': cur_next_vals,
                    'input_map': input_map,
                    'next_template_node': state['next_template_node'] + 1,
                })
        else:
            if next_node['type'] == 'object2part':
                complete_parts(scene_struct, metadata)
            input_map = {k: v for k, v in state['input_map'].items()}
            input_map[state['next_template_node']] = len(state['nodes'])
            next_node = {
                'type': next_node['type'],
                'inputs': [input_map[idx] for idx in next_node['inputs']],
            }
            
            states.append({
                'nodes': state['nodes'] + [next_node],
                'vals': state['vals'],
                'input_map': input_map,
                'next_template_node': state['next_template_node'] + 1,
            })
            

    # Actually instantiate the template with the solutions we've found
    text_questions, structured_questions, answers, box_token_mappings = [], [], [], []
    for state in final_states:
        structured_questions.append(state['nodes'])
        answer = post_process_part_name(state['answer'])
        if answer in metadata['types']['Shapename']:
            answer = metadata['types']['Shapename'][answer]
        answers.append(answer)
        text = random.choice(template['text'])
        # for name, val in state['vals'].items():
        #     if val in synonyms:
        #         val = random.choice(synonyms[val])
        #     elif val in metadata['types']['Shapename']:
        #         val = metadata['types']['Shapename'][val]
        #     else:
        #         val = post_process_part_name(val)
        #     text = text.replace(name, val)
        #     text = ' '.join(text.split())
        text = replace_optionals(text)
        text = ' '.join(text.split())
        text = other_heuristic(text, state['vals'])
        text, box_token_mapping = get_box_token_mapping(state, metadata, template, text, synonyms)
        text_questions.append(text)
        box_token_mappings.append(box_token_mapping)

    return text_questions, structured_questions, answers, box_token_mappings

# Matches a template parameter slot in question text, e.g. <Z>, <C>, <M2>, <S4>.
PARAM_REG = re.compile(r"<.\d?>")
# Matches the "anything else / any other thing(s)" phrase of exist-style
# questions; group 1 captures the phrase so its character span can be mapped
# back to the scene objects it refers to.
OTHER_QUES = re.compile(r"(?:Is|Are) there ((?:anything else|any other thing)s?) that")
# Same idea for counting questions ("How many other things ...").
OTHER_COUNT_QUES = re.compile(r"(?:How many|What number of) (other (?:thing|object)s?)")
def get_box_token_mapping(state, metadata, template, text, synonyms):
    """Fill parameter slots in `text` and map scene objects to token spans.

    Replaces each template slot (e.g. <Z>, <C2>) in `text` with its
    instantiated value from ``state['vals']`` (via `synonyms`, the Shapename
    table, or `post_process_part_name`), and builds ``box_token_mapping``:
    a dict {object-id: [[char_start, char_end], ...]} linking each object
    output by a filter/same template node to the character span(s) in the
    final question text that describe it.

    Returns the tuple (text, box_token_mapping).
    """
    box_token_mapping = {}
    # find the output objects (output_objs) for current node (super_node_idx, node_idx)
    # A "super node" is one node of the *template* program; after
    # instantiation it may correspond to several nodes of state['nodes'],
    # spanning (last_node_idx, curr_node_idx] in the input_map.
    node_outputs = {}
    for super_node_idx in state['input_map'].keys():
        if super_node_idx == 0:
            continue

        # Walk backwards over the instantiated nodes belonging to this
        # template node and take the first (i.e. last-executed) one whose
        # output type is an object or part (set).
        output_objs = []
        curr_node_idx = state['input_map'][super_node_idx]
        last_node_idx = state['input_map'][super_node_idx-1]
        node_idx = None
        for node_idx in range(curr_node_idx, last_node_idx, -1):
            curr_node = state['nodes'][node_idx]
            if metadata['_functions_by_name'][curr_node['type']]['output'] in ['Object', 'ObjectSet', 'Part', 'PartSet']:
                output_objs = curr_node['_output']
                break
        if type(output_objs) != list:
            output_objs = [output_objs]
        # Empty range (this template node produced no instantiated nodes):
        # inherit the outputs of the template node it takes input from.
        if node_idx is None:
            inp_super_node_idx = template['nodes'][super_node_idx]['inputs'][0]
            output_objs = node_outputs.get(inp_super_node_idx, {'output_objs':[]})['output_objs']
        node_outputs[super_node_idx] = {'node_idx': node_idx, 'output_objs': output_objs, 'tokens':[]}
        #if curr_node_idx == last_node_idx:
        #    pdb.set_trace()
    # A 'same_*' node immediately consumed by a filter node: narrow the
    # same-node's object set to the filtered result.
    for super_node_idx in state['input_map'].keys():
        if 'same' in template['nodes'][super_node_idx]['type']:
            if template['nodes'][super_node_idx+1]['inputs'][0] == super_node_idx and 'filter' in template['nodes'][super_node_idx+1]['type']:
                node_outputs[super_node_idx]['output_objs'] = node_outputs[super_node_idx+1]['output_objs']

    # find the super_node_idx for each param: {'<M>': super_node_idx}
    param_snodeidx_map = {}
    for super_node_idx in state['input_map'].keys():
        if 'side_inputs' not in template['nodes'][super_node_idx]:
            continue
        for side_input in template['nodes'][super_node_idx]['side_inputs']:
            param_snodeidx_map[side_input] = super_node_idx

    # Slots in the order they appear in the question text.
    sorted_side_inputs = re.findall(PARAM_REG, text)

    # find token idx for each side_input, and append result to node_outputs
    for side_input in sorted_side_inputs:
        super_node_idx = param_snodeidx_map[side_input]
        # find the token idx for current node
        val = state['vals'][side_input]
        if val in synonyms:
            val = random.choice(synonyms[val])
        elif val in metadata['types']['Shapename']:
            val = metadata['types']['Shapename'][val]
        else:
            val = post_process_part_name(val)

        # Span is computed against the text *before* replacement, then the
        # slot is substituted in place, keeping later spans consistent.
        token_start_idx = text.find(side_input)
        token_end_idx = token_start_idx+len(val)

        text = text.replace(side_input, val)
        text = ' '.join(text.split())

        # NOTE(review): last_word is only used by the commented-out
        # determinant logic below; it also raises IndexError if the slot is
        # at the very start of the text — confirm templates never do that.
        last_word = text[:token_start_idx].split()[-1]
        # add determinant
        # if last_word in ['a', 'the', 'other', 'another'] and len(node_outputs[super_node_idx]['tokens'])==0:
        #     if token_end_idx == token_start_idx:
        #         token_end_idx -= 1
        #     token_start_idx = token_start_idx-len(last_word)-1
        # NOTE(review): text[token_end_idx] can raise IndexError when the
        # substituted value ends the string — confirm templates always have
        # trailing text (e.g. a '?').
        if text[token_end_idx]=='s': # hacky way to handle 's' 'es', eg bikes
            token_end_idx += 1
        elif text[token_end_idx]=='e':
            token_end_idx += 2
        if token_end_idx != token_start_idx:
            node_outputs[super_node_idx]['tokens'].append([token_start_idx, token_end_idx])


    # only filter**, sam** node can produce obj output
    ## TODO: maybe also relate?
    for super_node_idx, o in node_outputs.items():
        super_node_type = template['nodes'][super_node_idx]['type']
        if 'filter' not in super_node_type and 'same' not in super_node_type:
            continue
        for obj in o['output_objs']:
            if obj not in box_token_mapping:
                box_token_mapping[obj] = []
            box_token_mapping[obj].extend(o['tokens'])

    # for nodes that does not have side_inputs (eg same, and, or)
    for super_node_idx, super_node in enumerate(template['nodes']):
        node_type = super_node['type']
        ## and, or do not need special handling, e.g., thing, to the right of, to the left of
        ## same
        if 'same' in node_type:
            # e.g. node type 'same_color' appears in text as "same color".
            same_str = ' '.join(node_type.split('_'))
            output_objs = node_outputs[super_node_idx]['output_objs']
            token_start_idx = text.find(same_str)
            token_end_idx = token_start_idx + len(same_str)
            for obj in output_objs:
                box_token_mapping[obj].append([token_start_idx, token_end_idx])
            # Also map the "other thing(s)" phrase of exist/count questions
            # back to the same objects.
            for other_re in [OTHER_QUES, OTHER_COUNT_QUES]:
                other_match = re.match(other_re, text)
                if other_match is not None:
                    for obj in output_objs:
                        box_token_mapping[obj].append(list(other_match.span(1)))


    # remove repeated entries
    # box_token_mapping = {obj: list(set(b)) for obj, b in box_token_mapping.items()}

    # print(text)
    # for obj, maps in box_token_mapping.items():
    #     print(obj, maps, [text[m[0]:m[1]] for m in maps])

    return text, box_token_mapping

def replace_optionals(s):
    """Randomly keep or drop each bracketed optional phrase in *s*.

    Every "[...]" span is independently kept (with the brackets stripped)
    or removed entirely, each with probability 0.5. With two optionals,

    "A [aa] B [bb]"

    yields each of "A aa B bb", "A  B bb", "A aa B ", "A  B " with
    probability 1/4. Spans are resolved left to right, consuming one
    random draw per span.
    """
    pat = re.compile(r'\[([^\[]*)\]')

    match = pat.search(s)
    while match is not None:
        start, end = match.span()
        kept = match.group(1) if random.random() > 0.5 else ''
        s = s[:start] + kept + s[end:]
        match = pat.search(s)
    return s

def post_process_part_name(s):
    '''
    Turn an internal part name like "door_right" into readable text like
    "right door": split on '_', drop a trailing plural marker word "s",
    and move trailing position words (right, left, front, back, center,
    mid) to the front. Non-string inputs (e.g. integer or boolean
    answers) are returned unchanged.

    Fixes two defects in the original implementation:
    - it looped forever when every word was a position word (e.g. "left"
      or "left_front" kept rotating); rotation is now bounded by the
      number of words.
    - it raised IndexError when the word list became empty (input "s");
      an empty list now yields ''.
    '''
    if not isinstance(s, str):
        return s
    special_words = ['right', 'left', 'front', 'back', 'center', 'mid']
    words = s.split('_')
    if words[-1] == 's':
        words.pop(-1)
    # Rotate trailing position words to the front, at most len(words)
    # times so all-special inputs terminate instead of cycling forever.
    moves = 0
    while words and moves < len(words) and words[-1] in special_words:
        words.insert(0, words.pop(-1))
        moves += 1
    return ' '.join(words)


def main(args):
    """Generate CLEVR-style questions for every input scene.

    Pipeline: load metadata, templates, scenes, and synonyms from the
    paths in ``args``; instantiate templates per scene with
    `instantiate_templates_dfs`; then write every generated question
    (text, program, answer, object/box-token mapping) to a single JSON
    file at ``args.output_questions_file``.
    """
    with open(args.metadata_file, 'r') as f:
        metadata = json.load(f)
        dataset = metadata['dataset']
        if dataset != 'CLEVR-v1.0':
            raise ValueError('Unrecognized dataset "%s"' % dataset)

    # Index function metadata by name for O(1) lookup during generation.
    functions_by_name = {}
    for f in metadata['functions']:
        functions_by_name[f['name']] = f
    metadata['_functions_by_name'] = functions_by_name

    # Flatten the shape hierarchy {hypernym: [shape, ...]} into a
    # shape -> hypernym map, then replace 'Shape' with the flat list.
    metadata['_shape_hier'] = {}
    for hypername in metadata['types']['Shape']:
        for shapename in metadata['types']['Shape'][hypername]:
            metadata['_shape_hier'][shapename] = hypername

    metadata['types']['Shape'] = list(metadata['_shape_hier'].keys())

    # 'Partname' is stored as a path to a JSON file; load its contents.
    # (Use a context manager so the handle is closed promptly; the
    # original json.load(open(...)) leaked the file handle.)
    with open(metadata['types']['Partname'], 'r') as f:
        metadata['types']['Partname'] = json.load(f)

    # Load templates from disk; keyed by (filename, index-within-file).
    num_loaded_templates = 0
    templates = {}
    for fn in os.listdir(args.template_dir):
        if not fn.endswith('.json'):
            continue
        with open(os.path.join(args.template_dir, fn), 'r') as f:
            for i, template in enumerate(json.load(f)):
                num_loaded_templates += 1
                templates[(fn, i)] = template
    print('Read %d templates from disk' % num_loaded_templates)

    def reset_counts():
        # Maps a template key (filename, index) to the number of questions
        # generated so far with that template.
        template_counts = {}
        # Maps a template key to {answer: count}, used for answer balancing
        # during rejection sampling.
        template_answer_counts = {}
        node_type_to_dtype = {n['name']: n['output'] for n in metadata['functions']}
        for key, template in templates.items():
            template_counts[key[:2]] = 0
            final_node_type = template['nodes'][-1]['type']
            final_dtype = node_type_to_dtype[final_node_type]
            answers = metadata['types'][final_dtype]
            if isinstance(answers, dict):  # Partname or Shape hierarchy
                # Flatten the per-shape lists into one deduplicated list.
                res = []
                for _parts in answers.values():
                    res.extend(_parts)
                answers = list(set(res))
            if final_dtype == 'Bool':
                answers = [True, False]
            if final_dtype == 'Integer':
                if metadata['dataset'] == 'CLEVR-v1.0':
                    answers = list(range(0, 11))
            template_answer_counts[key[:2]] = {}
            for a in answers:
                template_answer_counts[key[:2]][a] = 0
        return template_counts, template_answer_counts

    template_counts, template_answer_counts = reset_counts()

    # Read the file containing input scenes.
    with open(args.input_scene_file, 'r') as f:
        scene_data = json.load(f)
        all_scenes = scene_data['scenes']
        scene_info = scene_data['info']
    begin = args.scene_start_idx
    if args.num_scenes > 0:
        all_scenes = all_scenes[begin:begin + args.num_scenes]
    else:
        all_scenes = all_scenes[begin:]

    # Read synonyms file.
    with open(args.synonyms_json, 'r') as f:
        synonyms = json.load(f)

    questions = []
    scene_count = 0
    for i, scene in enumerate(all_scenes):
        scene_fn = scene['image_filename']
        scene_struct = scene
        print('starting image %s (%d / %d)'
                    % (scene_fn, i + 1, len(all_scenes)))

        # Periodically reset the balancing counters so that later scenes
        # are not starved by counts accumulated on earlier ones.
        if scene_count % args.reset_counts_every == 0:
            print('resetting counts')
            template_counts, template_answer_counts = reset_counts()
        scene_count += 1

        # Order templates by the number of questions we have so far for
        # those templates: a simple heuristic for a flat distribution
        # over templates.
        templates_items = sorted(templates.items(),
                                 key=lambda x: template_counts[x[0][:2]])
        num_instantiated = 0

        for (fn, idx), template in templates_items:
            if args.verbose:
                print('trying template ', fn, idx)
            if args.time_dfs and args.verbose:
                tic = time.time()

            # Seed with the running question count so output is
            # reproducible for a fixed scene/template ordering.
            random.seed(len(questions))
            np.random.seed(len(questions))
            ts, qs, ans, bmaps = instantiate_templates_dfs(
                                            scene_struct,
                                            template,
                                            metadata,
                                            template_answer_counts[(fn, idx)],
                                            synonyms,
                                            max_instances=args.instances_per_template,
                                            remove_redundant=args.remove_redundant,
                                            verbose=False)
            if args.time_dfs and args.verbose:
                toc = time.time()
                print('that took ', toc - tic)
            # Image index is the numeric suffix of the filename stem.
            image_index = int(os.path.splitext(scene_fn)[0].split('_')[-1])
            for t, q, a, bmap in zip(ts, qs, ans, bmaps):
                question_hash = get_question_hash(image_index, scene_struct, q, t)
                questions.append({
                    'split': scene_info['split'],
                    'image_filename': scene_fn,
                    'image_index': image_index,
                    'image': os.path.splitext(scene_fn)[0],
                    'question': t,
                    'program': q,
                    'answer': a,
                    'obj_map': bmap,
                    'template_filename': fn,
                    'question_family_index': idx,
                    'question_hash': question_hash,
                    'question_index': len(questions),
                })
            if len(ts) > 0:
                if args.verbose:
                    print('got one!')
                num_instantiated += 1
                template_counts[(fn, idx)] += 1
            elif args.verbose:
                print('did not get any =(')
            if num_instantiated >= args.templates_per_image:
                break

    # Rename "side_inputs" to "value_inputs" in every program function for
    # the public CLEVR release format. Functions with no side inputs get
    # an empty "value_inputs" list; the generation code above still relies
    # on the absence of "side_inputs", so the rename happens only here.
    for q in questions:
        if q['question_hash'] == "ERROR":
            continue
        for f in q['program']:
            if 'side_inputs' in f:
                f['value_inputs'] = f['side_inputs']
                del f['side_inputs']
            else:
                f['value_inputs'] = []

    # Drop questions that still contain unfilled parameter slots.
    exclude = "(<Z>)|(<C>)|(<M>)|(<S>)"
    questions = [x for x in questions if re.search(exclude, x['question']) is None]

    with open(args.output_questions_file, 'w') as f:
        print('Writing output to %s' % args.output_questions_file)
        json.dump({
                'info': scene_info,
                'questions': questions,
            }, f, indent=2)


if __name__ == '__main__':
    args = parser.parse_args()
    # Optionally run the whole pipeline under cProfile.
    if not args.profile:
        main(args)
    else:
        import cProfile
        cProfile.run('main(args)')

